def solve(matrix):
    """Zero out, in place, every row and column that contains a 0.

    The function first records which rows and columns hold a zero. Only
    then does it write zeros. Splitting the work this way stops zeros
    written during the second pass from cascading into further rows and
    columns.

    Args:
        matrix: A list of mutable rows (lists). Each cell is compared
            with ``== 0``. An empty matrix (``[]``) or empty rows are
            left unchanged.

    Returns:
        None. ``matrix`` and its row lists are modified in place.
    """
    zero_rows = set()
    zero_cols = set()
    # Iterate rows directly rather than indexing matrix[0] for a width:
    # that indexing raised IndexError on an empty matrix.
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            if value == 0:
                zero_rows.add(i)
                zero_cols.add(j)
    for i, row in enumerate(matrix):
        for j in range(len(row)):
            if i in zero_rows or j in zero_cols:
                row[j] = 0

if __name__ == "__main__":
    # Demo 1: a single interior zero clears its whole row and column.
    matrix = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]
    solve(matrix)
    print(matrix)
    # Demo 2: zeros at both ends of the first row clear that row and the
    # first and last columns.
    matrix = [[0, 1, 2, 0], [3, 4, 5, 2], [1, 3, 1, 5]]
    solve(matrix)
    print(matrix)
    